
from .mng_json import json_manager, helpSgltn, TroubleSgltn

class Tagger:
    """ComfyUI-style node that splices user tags into a block of text.

    Tags may be placed at the beginning, in the middle (after a comma or
    period) and/or at the end of the incoming text.
    """

    def __init__(self) -> None:
        # Project helpers: `trbl` collects troubleshooting output, `j_mngr`
        # provides log_events(), `help_data` exposes the help text for this node.
        self.trbl = TroubleSgltn()
        self.j_mngr = json_manager()
        self.help_data = helpSgltn()

    @staticmethod
    def join_punct(text: str, end_char: str = "") -> str:
        """Return the separator to put after ``text`` when more text follows.

        Args:
            text: Text that precedes the join; trailing whitespace is ignored
                when looking at its last character.
            end_char: Optional character to add before the space when ``text``
                does not already end in one of ``: ; - , .``.

        Returns:
            ``end_char + ' '`` when ``end_char`` is given and ``text`` has no
            closing punctuation, otherwise a single space.
        """
        if end_char and not text.rstrip().endswith((':', ';', '-', ',', '.')):
            return end_char + ' '
        # Users don't want the app to add a comma to their tags, so the
        # default joiner is just a space.
        return ' '

    @staticmethod
    def _insert_after(text: str, punct_index: int, insertion: str) -> str:
        """Insert ``insertion`` directly after the character at ``punct_index``.

        The insertion is separated by a single space on each side, whether or
        not the original text had a space after the punctuation mark.
        """
        head = text[:punct_index + 1]
        # Only spaces are dropped so that line breaks in the text survive.
        tail = text[punct_index + 1:].lstrip(' ')
        if not tail:
            return head + ' ' + insertion
        return head + ' ' + insertion + Tagger.join_punct(insertion) + tail

    @staticmethod
    def enhanced_text_placement(generated_text: str, b_tags: str = "", m_tags: str = "", e_tags: str = "", pref_periods: bool = False):
        """Place tags at the beginning, middle and/or end of ``generated_text``.

        Middle tags are inserted after a comma or period in the text:

        * no such punctuation: the middle tags are appended to the end tags;
        * one or two marks: inserted after the first mark;
        * three or more marks: inserted after mark number ``count // 2``.

        Args:
            generated_text (str): The text block the tags are added to.
            b_tags (str): Tags placed before the text.
            m_tags (str): Tags placed inside the text (see above).
            e_tags (str): Tags placed after the text.
            pref_periods (bool): When True and the text contains more than one
                period, only periods are considered as middle insertion points.

        Returns:
            str: The text with the tags integrated, with leading/trailing
            commas and spaces removed. When no tags are supplied the text is
            returned with a closing ``'. '`` (or ``' '`` if it already ends in
            punctuation) appended; empty text is returned unchanged.
        """
        text = generated_text or ""
        # str.strip() returns a new string, so the stripped values must be kept.
        beginning_text, middle_text, end_text = (
            (tags or "").strip() for tags in (b_tags, m_tags, e_tags)
        )

        if not (beginning_text or middle_text or end_text):
            if not text.strip():
                return text  # Nothing to terminate; avoid producing a lone '. '
            return text + Tagger.join_punct(text, '.')

        if middle_text:
            # Decide which punctuation marks are valid insertion points.
            if pref_periods and text.count('.') > 1:
                search_punct = "."
            else:
                search_punct = ",."
            marks = [i for i, char in enumerate(text) if char in search_punct]

            if not marks:
                # No place to insert: the middle tags follow the end tags.
                if end_text:
                    end_text = end_text + Tagger.join_punct(end_text) + middle_text
                else:
                    end_text = middle_text
            elif len(marks) <= 2:
                # Insert after the first punctuation mark.
                text = Tagger._insert_after(text, marks[0], middle_text)
            else:
                # Insert after the midpoint punctuation mark.
                text = Tagger._insert_after(text, marks[len(marks) // 2 - 1], middle_text)

        # Integrate beginning and end text
        if beginning_text:
            text = beginning_text + Tagger.join_punct(beginning_text) + text
        if end_text:
            # Drop trailing whitespace first so the separator attaches to the
            # last word instead of producing "word . tag".
            body = text.rstrip()
            text = (body + Tagger.join_punct(body, '.') + end_text) if body else end_text

        return text.strip(', ')  # Ensure no leading or trailing commas

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs: three tag fields, the period preference and an optional text."""
        return {
            "required": {
                "Beginning_tags": ("STRING", {"multiline": True}),
                "Middle_tags": ("STRING", {"multiline": True}),
                "Prefer_middle_tag_after_period": ("BOOLEAN", {"default": True}),
                "End_tags": ("STRING", {"multiline": True})
            },
            "optional": {
                "text": ("STRING", {"forceInput": True, "multiline": True})
            },

            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("STRING", "STRING", "STRING")
    RETURN_NAMES = ("tagged_text", "help", "troubleshooting")

    FUNCTION = "gogo"

    OUTPUT_NODE = False

    CATEGORY = "Plush/Utils"

    def gogo(self, unique_id=None, text="", Beginning_tags="", Middle_tags="", End_tags="", Prefer_middle_tag_after_period=True) -> tuple:
        """Insert the supplied tags into ``text`` and return the node outputs.

        ``text`` is declared as an optional input, so it defaults to an empty
        string instead of being a required argument.

        Returns:
            tuple: ``(tagged_text, help, troubleshooting)``.
        """
        _help = self.help_data.tagger_help
        if unique_id:
            self.trbl.reset(f'Tagger, Node #{unique_id}')
        else:
            self.trbl.reset('Tagger')

        if Middle_tags:
            if Prefer_middle_tag_after_period:
                message = "Prefer middle tags to be inserted after a period."
            else:
                message = "Allow middle tags to be inserted after either a period or a comma."

            self.j_mngr.log_events(message,
                                   is_trouble=True)

        output = Tagger.enhanced_text_placement(text or "", Beginning_tags, Middle_tags, End_tags, Prefer_middle_tag_after_period)

        self.j_mngr.log_events("Inserting tags into text block.",
                               is_trouble=True)

        return (output, _help, self.trbl.get_troubles())
    


class mulTextSwitch:
    """Switch node that forwards one of up to three optional text inputs."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the selector (1-3) and the three optional text inputs."""
        selector_spec = ("INT", {"max": 3, "min": 1, "step": 1, "default": 1, "display": "number"})
        # A fresh options dict is built for every input so none are shared.
        text_inputs = {
            f"Input_{slot}": ("STRING", {"multiline": True, "forceInput": True})
            for slot in (1, 2, 3)
        }
        return {"required": {"active_input": selector_spec}, "optional": text_inputs}

    RETURN_TYPES = ("STRING", )
    RETURN_NAMES = ("Multiline Text", )

    FUNCTION = "gogo"

    OUTPUT_NODE = False

    CATEGORY = "Plush/Utils"

    def gogo(self, active_input, Input_1=None, Input_2=None, Input_3=None):
        """Return the text of the selected input as a one-element tuple.

        Raises:
            Exception: If the selected input is missing or empty, or the
                selector does not match any of the three inputs.
        """
        slots = enumerate((Input_1, Input_2, Input_3), start=1)
        # First slot whose number matches the selector and that holds text.
        chosen = next(
            (value for number, value in slots if active_input == number and value),
            "",
        )

        if chosen:
            return (chosen, )

        raise Exception ("Missing text input, check selction")
    


class ImgTextSwitch:
    """Switch node that forwards one of up to three optional (text, image) pairs."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the selector (1-3) and the three optional text/image pairs."""
        return {
            "required": {
                "active_input": ("INT", {"max": 3, "min": 1, "step": 1, "default": 1, "display": "number"})
            },
            "optional": {
                "Text_1": ("STRING", {"multiline": True, "forceInput": True}),
                "Image_1" : ("IMAGE", {"default": None}),
                "Text_2": ("STRING", {"multiline": True, "forceInput": True}),
                "Image_2" : ("IMAGE", {"default": None}),
                "Text_3": ("STRING", {"multiline": True, "forceInput": True}),
                "Image_3" : ("IMAGE", {"default": None})
            }
        }

    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("Multiline Text", "Image")

    FUNCTION = "gogo"

    OUTPUT_NODE = False

    CATEGORY = "Plush/Utils"

    def gogo(self, active_input, Text_1=None, Image_1=None, Text_2=None, Image_2=None, Text_3=None, Image_3=None):
        """Return the text and image of the selected pair.

        Args:
            active_input: Number (1-3) of the pair to forward.
            Text_1, Text_2, Text_3: Optional text inputs.
            Image_1, Image_2, Image_3: Optional image inputs.

        Returns:
            tuple: ``(text, image)`` of the selected pair; either element may
            be empty/None as long as the other one is present.

        Raises:
            Exception: If the selected pair has neither text nor image, or the
                selector does not match any pair.
        """
        ret_text = ""
        ret_img = None

        if active_input == 1:
            ret_text = Text_1
            ret_img = Image_1
        elif active_input == 2:
            ret_text = Text_2
            ret_img = Image_2
        elif active_input == 3:
            ret_text = Text_3
            ret_img = Image_3

        # Compare the image against None instead of testing its truth value.
        # NOTE(review): IMAGE values are presumably array/tensor objects, whose
        # truth value is ambiguous and raises - confirm against the host app.
        if not ret_text and ret_img is None:
            raise Exception ("Missing text and image input, check selction")

        return (ret_text, ret_img)
    
# Single table of the nodes defined in this module:
# (node identifier, implementing class, display name).
_NODE_REGISTRY = (
    ("mulTextSwitch", mulTextSwitch, "MultiLine Text Switch"),
    ("ImgTextSwitch", ImgTextSwitch, "Image & Text Switch"),
    ("Tagger", Tagger, "Tagger"),
)

# Node identifier -> class implementing the node.
NODE_CLASS_MAPPINGS = {
    node_id: node_class for node_id, node_class, _ in _NODE_REGISTRY
}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    node_id: display_name for node_id, _, display_name in _NODE_REGISTRY
}
